{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating, Evaluating, and Deploying a Fraud Detection Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Introduction\n",
    "\n",
    "In this notebook, we'll demonstrate a data engineering and data science workflow with an end-to-end sample. The scenario is to build a model that detects fraudulent credit card transactions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 1: Load the Data\n",
    "\n",
    "The dataset contains credit card transactions made by European cardholders in September 2013.\n",
    "The transactions occurred over two days and include 492 frauds out of 284,807 transactions. The dataset is highly imbalanced: the positive class (frauds) accounts for 0.172% of all transactions.\n",
    "\n",
    "It contains only numerical input variables, which are the result of a PCA transformation. Unfortunately, due to confidentiality issues, the original features and more background information about the data cannot be provided. Features V1, V2, … V28 are the principal components obtained with PCA; the only features that have not been transformed with PCA are 'Time' and 'Amount'. 'Time' contains the seconds elapsed between each transaction and the first transaction in the dataset. 'Amount' is the transaction amount, which can be used, for instance, for example-dependent cost-sensitive learning. 'Class' is the response variable; it takes the value 1 in case of fraud and 0 otherwise.\n",
    "\n",
    "Given the class imbalance ratio, we recommend measuring model quality with the Area Under the Precision-Recall Curve (AUPRC). Plain accuracy read off the confusion matrix is not meaningful for imbalanced classification, because a model that never predicts fraud would still be correct on more than 99.8% of the transactions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "- creditcard.csv\n",
    "\n",
    "|\"Time\"|\"V1\"|\"V2\"|\"V3\"|\"V4\"|\"V5\"|\"V6\"|\"V7\"|\"V8\"|\"V9\"|\"V10\"|\"V11\"|\"V12\"|\"V13\"|\"V14\"|\"V15\"|\"V16\"|\"V17\"|\"V18\"|\"V19\"|\"V20\"|\"V21\"|\"V22\"|\"V23\"|\"V24\"|\"V25\"|\"V26\"|\"V27\"|\"V28\"|\"Amount\"|\"Class\"|\n",
    "|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
    "|0|-1.3598071336738|-0.0727811733098497|2.53634673796914|1.37815522427443|-0.338320769942518|0.462387777762292|0.239598554061257|0.0986979012610507|0.363786969611213|0.0907941719789316|-0.551599533260813|-0.617800855762348|-0.991389847235408|-0.311169353699879|1.46817697209427|-0.470400525259478|0.207971241929242|0.0257905801985591|0.403992960255733|0.251412098239705|-0.018306777944153|0.277837575558899|-0.110473910188767|0.0669280749146731|0.128539358273528|-0.189114843888824|0.133558376740387|-0.0210530534538215|149.62|\"0\"|\n",
    "|0|1.19185711131486|0.26615071205963|0.16648011335321|0.448154078460911|0.0600176492822243|-0.0823608088155687|-0.0788029833323113|0.0851016549148104|-0.255425128109186|-0.166974414004614|1.61272666105479|1.06523531137287|0.48909501589608|-0.143772296441519|0.635558093258208|0.463917041022171|-0.114804663102346|-0.183361270123994|-0.145783041325259|-0.0690831352230203|-0.225775248033138|-0.638671952771851|0.101288021253234|-0.339846475529127|0.167170404418143|0.125894532368176|-0.00898309914322813|0.0147241691924927|2.69|\"0\"|"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install libraries\n",
    "In this notebook, we'll use `imblearn` (distributed on PyPI as `imbalanced-learn`), which needs to be installed first. The PySpark kernel restarts after `%pip install`, so we need to install the library before running any other cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install imbalanced-learn (imported as imblearn) for SMOTE\n",
    "%pip install imbalanced-learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**By defining below parameters, we can apply this notebook on different datasets easily.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "IS_CUSTOM_DATA = False  # if True, dataset has to be uploaded manually\n",
    "\n",
    "TARGET_COL = \"Class\"  # target column name\n",
    "IS_SAMPLE = False  # if True, use only <SAMPLE_ROWS> rows of data for training, otherwise use all data\n",
    "SAMPLE_ROWS = 5000  # if IS_SAMPLE is True, use only this number of rows for training\n",
    "\n",
    "DATA_FOLDER = \"Files/fraud-detection\"  # folder with data files (no trailing slash; subpaths are appended with \"/\")\n",
    "DATA_FILE = \"creditcard.csv\"  # data file name\n",
    "\n",
    "EXPERIMENT_NAME = \"aisample-fraud\"  # mlflow experiment name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Download the dataset and upload it to the lakehouse\n",
    "\n",
    "**Please add a lakehouse to the notebook before running it.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "if not IS_CUSTOM_DATA:\n",
    "    # Download the demo data files into the lakehouse if they do not exist yet\n",
    "    import os, requests\n",
    "\n",
    "    remote_url = \"https://synapseaisolutionsa.blob.core.windows.net/public/Credit_Card_Fraud_Detection\"\n",
    "    fname = \"creditcard.csv\"\n",
    "    download_path = f\"/lakehouse/default/{DATA_FOLDER}/raw\"\n",
    "\n",
    "    if not os.path.exists(\"/lakehouse/default\"):\n",
    "        raise FileNotFoundError(\"Default lakehouse not found, please add a lakehouse.\")\n",
    "    os.makedirs(download_path, exist_ok=True)\n",
    "    if not os.path.exists(f\"{download_path}/{fname}\"):\n",
    "        r = requests.get(f\"{remote_url}/{fname}\", timeout=30)\n",
    "        r.raise_for_status()  # fail early instead of writing an error page to disk\n",
    "        with open(f\"{download_path}/{fname}\", \"wb\") as f:\n",
    "            f.write(r.content)\n",
    "    print(\"Downloaded demo data files into lakehouse.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# to record the notebook running time\n",
    "import time\n",
    "\n",
    "ts = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Read data from lakehouse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "df = (\n",
    "    spark.read.format(\"csv\")\n",
    "    .option(\"header\", \"true\")\n",
    "    .option(\"inferSchema\", True)\n",
    "    .load(f\"{DATA_FOLDER}/raw/{DATA_FILE}\")\n",
    "    .cache()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 2: Exploratory Data Analysis\n",
    "\n",
    "### Display Raw Data\n",
    "\n",
    "We can explore the raw data with `display`, compute some basic statistics, or show chart views."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# print dataset basic info\n",
    "print(\"records read: \" + str(df.count()))\n",
    "print(\"Schema: \")\n",
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Cast columns into the correct types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "import pyspark.sql.functions as F\n",
    "\n",
    "df_columns = df.columns\n",
    "df_columns.remove(TARGET_COL)\n",
    "\n",
    "# to make sure the TARGET_COL is the last column\n",
    "df = df.select(df_columns + [TARGET_COL]).withColumn(\n",
    "    TARGET_COL, F.col(TARGET_COL).cast(\"int\")\n",
    ")\n",
    "\n",
    "if IS_SAMPLE:\n",
    "    df = df.limit(SAMPLE_ROWS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 3: Model Development and Deployment\n",
    "So far we have explored the dataset, checked the schema, adjusted the column order, and cast the columns to the correct types.\n",
    "\n",
    "Next, we'll train a LightGBM model to classify fraudulent transactions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Prepare training and testing data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# Split the dataset into train and test\n",
    "train, test = df.randomSplit([0.85, 0.15], seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# Assemble the feature columns into a single vector column\n",
    "from pyspark.ml.feature import VectorAssembler\n",
    "\n",
    "feature_cols = df.columns[:-1]\n",
    "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
    "train_data = featurizer.transform(train)[TARGET_COL, \"features\"]\n",
    "test_data = featurizer.transform(test)[TARGET_COL, \"features\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Check data volume and imbalance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "display(train_data.groupBy(TARGET_COL).count())"
   ]
  },
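  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The counts quoted in Step 1 (492 frauds out of 284,807 transactions) make it easy to check why plain accuracy is a poor yardstick for this data. The cell below is a small illustrative sketch, not part of the pipeline: the variable names are hypothetical, and the numbers are the published dataset totals rather than the counts of the training split shown above. It computes the fraud rate and the accuracy of a trivial classifier that labels every transaction as non-fraud."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative arithmetic only; the counts come from the dataset description in Step 1\n",
    "n_total = 284_807  # all transactions\n",
    "n_fraud = 492  # positive class (fraud)\n",
    "\n",
    "fraud_rate = n_fraud / n_total\n",
    "# accuracy of a trivial classifier that labels every transaction as non-fraud\n",
    "always_negative_accuracy = (n_total - n_fraud) / n_total\n",
    "\n",
    "print(f\"Fraud rate: {fraud_rate:.4%}\")\n",
    "print(f\"Accuracy when always predicting non-fraud: {always_negative_accuracy:.4%}\")\n",
    "print(\"Fraud cases caught by that classifier: 0\")"
   ]
  },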
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Handle imbalanced data\n",
    "We'll apply [SMOTE](https://arxiv.org/abs/1106.1813) (Synthetic Minority Over-sampling Technique) to handle the imbalanced data. A dataset is imbalanced if the classification categories are not approximately equally represented. Real-world datasets are often predominantly composed of \"normal\" examples with only a small percentage of \"abnormal\" or \"interesting\" examples, and the cost of misclassifying an abnormal example as normal is often much higher than the cost of the reverse error. Under-sampling the majority (normal) class has been proposed as a way to increase a classifier's sensitivity to the minority class. The SMOTE paper shows that combining over-sampling of the minority (abnormal) class, using synthetic examples, with under-sampling of the majority (normal) class can achieve better classifier performance (in ROC space) than only under-sampling the majority class."
   ]
  },
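  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, the cell below is a minimal NumPy sketch of the interpolation step at the heart of SMOTE: pick a minority sample, pick one of its k nearest minority neighbors, and create a synthetic point somewhere on the line segment between them. It is a simplified illustration under stated assumptions, not the `imblearn` implementation used later in this notebook; the function name `smote_sketch` and the toy data are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def smote_sketch(X_minority, n_new, k=5, seed=42):\n",
    "    \"\"\"Create n_new synthetic rows by interpolating between minority-class neighbors.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    X_minority = np.asarray(X_minority, dtype=float)\n",
    "    n = len(X_minority)\n",
    "    k = min(k, n - 1)  # a point cannot have more neighbors than n - 1\n",
    "\n",
    "    # pairwise Euclidean distances between minority samples\n",
    "    diffs = X_minority[:, None, :] - X_minority[None, :, :]\n",
    "    dists = np.sqrt((diffs**2).sum(axis=2))\n",
    "    np.fill_diagonal(dists, np.inf)  # exclude each point from its own neighbors\n",
    "    neighbors = np.argsort(dists, axis=1)[:, :k]\n",
    "\n",
    "    base = rng.integers(0, n, size=n_new)  # random minority samples\n",
    "    chosen = neighbors[base, rng.integers(0, k, size=n_new)]  # one neighbor of each\n",
    "    gap = rng.random((n_new, 1))  # interpolation factor in [0, 1)\n",
    "    return X_minority[base] + gap * (X_minority[chosen] - X_minority[base])\n",
    "\n",
    "\n",
    "# toy minority class: the four corners of the unit square (hypothetical data)\n",
    "toy_minority = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])\n",
    "synthetic = smote_sketch(toy_minority, n_new=6, k=2)\n",
    "print(synthetic.shape)\n",
    "print(synthetic)"
   ]
  },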
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
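   },
   "source": [
    "#### Apply SMOTE to create new training data\n",
    "`imblearn` works with pandas DataFrames and NumPy arrays, not PySpark DataFrames, so we first convert the training data to pandas. SMOTE is applied to the training set only, so the test set keeps the original class distribution."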
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.functions import vector_to_array, array_to_vector\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "from imblearn.over_sampling import SMOTE\n",
    "\n",
    "train_data_array = train_data.withColumn(\"features\", vector_to_array(\"features\"))\n",
    "\n",
    "train_data_pd = train_data_array.toPandas()\n",
    "\n",
    "X = train_data_pd[\"features\"].to_numpy()\n",
    "y = train_data_pd[TARGET_COL].to_numpy()\n",
    "print(\"Original dataset shape %s\" % Counter(y))\n",
    "\n",
    "X = np.array([np.array(x) for x in X])\n",
    "\n",
    "sm = SMOTE(random_state=42)\n",
    "X_res, y_res = sm.fit_resample(X, y)\n",
    "print(\"Resampled dataset shape %s\" % Counter(y_res))\n",
    "\n",
    "new_train_data = tuple(zip(X_res.tolist(), y_res.tolist()))\n",
    "dataColumns = [\"features\", TARGET_COL]\n",
    "new_train_data = spark.createDataFrame(data=new_train_data, schema=dataColumns)\n",
    "new_train_data = new_train_data.withColumn(\"features\", array_to_vector(\"features\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Define the Model\n",
    "\n",
    "With our data in place, we can now define the model. We'll use a LightGBM classifier in this notebook.\n",
    "\n",
    "We'll use SynapseML to implement the model in a few lines of code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "from synapse.ml.lightgbm import LightGBMClassifier\n",
    "\n",
    "# trained on the original (imbalanced) data, so let LightGBM reweight the classes\n",
    "model = LightGBMClassifier(\n",
    "    objective=\"binary\", featuresCol=\"features\", labelCol=TARGET_COL, isUnbalance=True\n",
    ")\n",
    "# trained on the SMOTE-resampled (balanced) data, so no reweighting is needed\n",
    "smote_model = LightGBMClassifier(\n",
    "    objective=\"binary\", featuresCol=\"features\", labelCol=TARGET_COL, isUnbalance=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Model training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "model = model.fit(train_data)\n",
    "smote_model = smote_model.fit(new_train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Model Explanation\n",
    "Here we show the importance of each feature for the model trained on the original data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "feature_importances = model.getFeatureImportances()\n",
    "fi = pd.Series(feature_importances, index=feature_cols)\n",
    "fi = fi.sort_values(ascending=True)\n",
    "f_index = fi.index\n",
    "f_values = fi.values\n",
    "\n",
    "# print feature importances\n",
    "print(\"f_index:\", f_index)\n",
    "print(\"f_values:\", f_values)\n",
    "\n",
    "# plot\n",
    "x_index = list(range(len(fi)))\n",
    "x_index = [x / len(fi) for x in x_index]\n",
    "plt.rcParams[\"figure.figsize\"] = (20, 20)\n",
    "plt.barh(\n",
    "    x_index, f_values, height=0.028, align=\"center\", color=\"tan\", tick_label=f_index\n",
    ")\n",
    "plt.xlabel(\"importances\")\n",
    "plt.ylabel(\"features\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Model Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "predictions = model.transform(test_data)\n",
    "predictions.limit(10).toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "from synapse.ml.train import ComputeModelStatistics\n",
    "\n",
    "metrics = ComputeModelStatistics(\n",
    "    evaluationMetric=\"classification\", labelCol=TARGET_COL, scoredLabelsCol=\"prediction\"\n",
    ").transform(predictions)\n",
    "display(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# collect confusion matrix value\n",
    "cm = metrics.select(\"confusion_matrix\").collect()[0][0].toArray()\n",
    "print(cm)"
   ]
  },
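  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because accuracy is uninformative for imbalanced data, it helps to read precision and recall for the fraud class directly off a confusion matrix. The cell below is a small sketch of that calculation. It assumes the layout used by the heatmap in the next cell (rows are true labels, columns are predicted labels, index 1 is fraud); the helper `precision_recall_f1` and the toy counts are hypothetical and are not results from this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def precision_recall_f1(conf_matrix):\n",
    "    \"\"\"Return precision, recall and F1 for the positive class (index 1).\"\"\"\n",
    "    conf_matrix = np.asarray(conf_matrix, dtype=float)\n",
    "    tp = conf_matrix[1, 1]  # fraud predicted as fraud\n",
    "    fp = conf_matrix[0, 1]  # non-fraud predicted as fraud\n",
    "    fn = conf_matrix[1, 0]  # fraud predicted as non-fraud\n",
    "    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n",
    "    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n",
    "    denom = precision + recall\n",
    "    f1 = 2 * precision * recall / denom if denom > 0 else 0.0\n",
    "    return precision, recall, f1\n",
    "\n",
    "\n",
    "# hypothetical counts for illustration only\n",
    "toy_cm = [[990, 5], [2, 3]]\n",
    "precision, recall, f1 = precision_recall_f1(toy_cm)\n",
    "print(f\"precision={precision:.3f}, recall={recall:.3f}, f1={f1:.3f}\")"
   ]
  },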
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# plot confusion matrix\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set(rc={\"figure.figsize\": (6, 4.5)})\n",
    "ax = sns.heatmap(cm, annot=True, fmt=\".20g\")\n",
    "ax.set_title(\"Confusion Matrix\")\n",
    "ax.set_xlabel(\"Predicted label\")\n",
    "ax.set_ylabel(\"True label\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "\n",
    "def evaluate(predictions):\n",
    "    \"\"\"\n",
    "    Evaluate the model by computing AUROC and AUPRC with the predictions.\n",
    "    \"\"\"\n",
    "\n",
    "    # initialize the binary evaluator on the raw scores rather than the hard 0/1 labels,\n",
    "    # which would reduce each curve to a single threshold\n",
    "    # (assumes the classifier's default \"rawPrediction\" output column)\n",
    "    evaluator = BinaryClassificationEvaluator(\n",
    "        rawPredictionCol=\"rawPrediction\", labelCol=TARGET_COL\n",
    "    )\n",
    "\n",
    "    _evaluator = lambda metric: evaluator.setMetricName(metric).evaluate(predictions)\n",
    "\n",
    "    # calculate AUROC, baseline 0.5\n",
    "    auroc = _evaluator(\"areaUnderROC\")\n",
    "    print(f\"AUROC: {auroc:.4f}\")\n",
    "\n",
    "    # calculate AUPRC, baseline positive rate (0.172% in the demo data)\n",
    "    auprc = _evaluator(\"areaUnderPR\")\n",
    "    print(f\"AUPRC: {auprc:.4f}\")\n",
    "\n",
    "    return auroc, auprc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the original model\n",
    "auroc, auprc = evaluate(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# evaluate the SMOTE model\n",
    "new_predictions = smote_model.transform(test_data)\n",
    "new_auroc, new_auprc = evaluate(new_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "if new_auprc > auprc:\n",
    "    # Use the model trained on SMOTE data if it has a higher AUPRC\n",
    "    model = smote_model\n",
    "    auprc = new_auprc\n",
    "    auroc = new_auroc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Log and Load the Model with MLflow\n",
    "Now that we have a trained model, we can save it for later use. Here we use MLflow to log the metrics and the model, and then load the model back for prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# setup mlflow\n",
    "import mlflow\n",
    "import trident.mlflow\n",
    "from trident.mlflow import get_sds_url\n",
    "\n",
    "mlflow.set_tracking_uri(get_sds_url())\n",
    "mlflow.set_registry_uri(get_sds_url())\n",
    "mlflow.set_experiment(EXPERIMENT_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# log model, metrics and params\n",
    "with mlflow.start_run() as run:\n",
    "    print(\"log model:\")\n",
    "    mlflow.spark.log_model(\n",
    "        model,\n",
    "        f\"{EXPERIMENT_NAME}-lightgbm\",\n",
    "        registered_model_name=f\"{EXPERIMENT_NAME}-lightgbm\",\n",
    "        dfs_tmpdir=\"Files/spark\",\n",
    "    )\n",
    "\n",
    "    print(\"log metrics:\")\n",
    "    mlflow.log_metrics({\"AUPRC\": auprc, \"AUROC\": auroc})\n",
    "\n",
    "    print(\"log parameters:\")\n",
    "    mlflow.log_params({\"DATA_FILE\": DATA_FILE})\n",
    "\n",
    "    model_uri = f\"runs:/{run.info.run_id}/{EXPERIMENT_NAME}-lightgbm\"\n",
    "    print(\"Model saved in run %s\" % run.info.run_id)\n",
    "    print(f\"Model URI: {model_uri}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model back\n",
    "loaded_model = mlflow.spark.load_model(model_uri, dfs_tmpdir=\"Files/spark\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "## Step 4: Save Prediction Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "source": [
    "### Model Deployment and Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "batch_predictions = loaded_model.transform(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# save the predictions to the lakehouse as a Delta table\n",
    "batch_predictions.write.format(\"delta\").mode(\"overwrite\").save(\n",
    "    f\"{DATA_FOLDER}/predictions/batch_predictions\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "print(f\"Full run took {int(time.time() - ts)} seconds.\")"
   ]
  }
 ],
 "metadata": {
  "kernel_info": {
   "name": "synapse_pyspark"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "notebook_environment": {},
  "save_output": true,
  "spark_compute": {
   "compute_id": "/trident/default",
   "session_options": {
    "conf": {
     "spark.dynamicAllocation.enabled": "false",
     "spark.dynamicAllocation.maxExecutors": "2",
     "spark.dynamicAllocation.minExecutors": "2",
     "spark.livy.synapse.ipythonInterpreter.enabled": "true"
    },
    "driverCores": 8,
    "driverMemory": "56g",
    "enableDebugMode": false,
    "executorCores": 8,
    "executorMemory": "56g",
    "keepAliveTimeout": 30,
    "numExecutors": 5
   }
  },
  "synapse_widget": {
   "state": {},
   "version": "0.1"
  },
  "trident": {
   "lakehouse": {
    "default_lakehouse": "",
    "known_lakehouses": [
     {
      "id": ""
     }
    ]
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "8cebba326b76ca708172f0a6a24a89689a3b64f83dbd9353b827f2f4b33d3f80"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
